{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import scipy.io\n",
    "from sklearn.metrics import classification_report,confusion_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder name of the dataset to load; change it to run the notebook on a different dataset.\n",
    "dataset = 'CUB'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From the .mat files extract all the features from resnet and the attribute splits. \n",
    "- The res101 contains features and the corresponding labels.\n",
    "- att_splits contains the different splits for trainval, train, val and test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root folder holding one sub-folder per dataset. Set the XLSA17_DATA_DIR\n",
    "# environment variable to point somewhere else without editing this cell;\n",
    "# the default keeps the location this notebook was originally run with.\n",
    "data_root = os.environ.get('XLSA17_DATA_DIR', '/home/shrisha/Research/ZSL/dataset/xlsa17/data/')\n",
    "dataset_dir = os.path.join(data_root, dataset)\n",
    "\n",
    "# res101.mat -> features and labels, att_splits.mat -> attributes and split locations\n",
    "res101 = scipy.io.loadmat(os.path.join(dataset_dir, 'res101.mat'))\n",
    "att_splits = scipy.io.loadmat(os.path.join(dataset_dir, 'att_splits.mat'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['__header__', '__version__', '__globals__', 'image_files', 'features', 'labels'])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Variables stored in res101.mat\n",
    "res101.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['__header__', '__version__', '__globals__', 'allclasses_names', 'att', 'original_att', 'test_seen_loc', 'test_unseen_loc', 'train_loc', 'trainval_loc', 'val_loc'])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Variables stored in att_splits.mat\n",
    "att_splits.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Names of the att_splits entries that hold the sample locations of each split\n",
    "trainval_loc = 'trainval_loc'\n",
    "train_loc = 'train_loc'\n",
    "val_loc = 'val_loc'\n",
    "# The test split used in this notebook is the unseen-class test set\n",
    "test_loc = 'test_unseen_loc'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need the corresponding ground-truth labels/classes for each training example for all our train, val, trainval and test set according to the split locations provided.\n",
    "In this example we have used the `CUB` dataset which has 200 unique classes overall."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class id of every sample, then the subset belonging to each split.\n",
    "# The stored split locations are shifted by one before being used as row indices.\n",
    "labels = res101['labels']\n",
    "labels_train, labels_val, labels_trainval, labels_test = (\n",
    "    labels[np.squeeze(att_splits[split]) - 1]\n",
    "    for split in (train_loc, val_loc, trainval_loc, test_loc)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[151],\n",
       "       [151],\n",
       "       [151],\n",
       "       [151],\n",
       "       [151],\n",
       "       [151],\n",
       "       [151],\n",
       "       [151],\n",
       "       [151],\n",
       "       [151]], dtype=uint8)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the first ten training labels\n",
    "labels_train[:10,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,\n",
       "        14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,\n",
       "        27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,\n",
       "        40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,\n",
       "        53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,\n",
       "        66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,\n",
       "        79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,\n",
       "        92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103, 104,\n",
       "       105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n",
       "       118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130,\n",
       "       131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,\n",
       "       144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156,\n",
       "       157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,\n",
       "       170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182,\n",
       "       183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195,\n",
       "       196, 197, 198, 199, 200], dtype=uint8)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# All distinct class ids present in the dataset\n",
    "unique_labels = np.unique(labels)\n",
    "unique_labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In a typical zero-shot learning scenario, there are no overlapping classes between training and testing phase, i.e the train classes are completely different from the test classes. So let us verify if there are any overlapping classes in the test and train scenario.\n",
    "- During training phase we have `z` classes\n",
    "- During the testing phase we have `z'` classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct class ids that occur in each split\n",
    "train_labels_seen, val_labels_unseen, trainval_labels_seen, test_labels_unseen = map(\n",
    "    np.unique, (labels_train, labels_val, labels_trainval, labels_test)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of overlapping classes between train and val: 0\n",
      "Number of overlapping classes between trainval and test: 0\n"
     ]
    }
   ],
   "source": [
    "# Zero-shot sanity check: the class sets of paired splits must be disjoint\n",
    "n_shared_train_val = len(np.intersect1d(train_labels_seen, val_labels_unseen))\n",
    "n_shared_trainval_test = len(np.intersect1d(trainval_labels_seen, test_labels_unseen))\n",
    "print(\"Number of overlapping classes between train and val:\", n_shared_train_val)\n",
    "print(\"Number of overlapping classes between trainval and test:\", n_shared_trainval_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_class_index(split_key, split_classes):\n",
    "    \"\"\"Return the labels of one split as indices into `split_classes`.\n",
    "\n",
    "    The original class ids are re-read from `res101` on every call, so running\n",
    "    this cell more than once always gives the same result, and the global\n",
    "    `labels` array is no longer overwritten by a loop variable.\n",
    "\n",
    "    split_key     -- name of the att_splits entry holding the sample locations\n",
    "    split_classes -- sorted unique class ids of that split (output of np.unique)\n",
    "\n",
    "    Returns an integer array shaped like the split's label array, where each\n",
    "    entry is the position of the sample's class id inside `split_classes`.\n",
    "    \"\"\"\n",
    "    original_ids = res101['labels'][np.squeeze(att_splits[split_key]) - 1]\n",
    "    # Every id is present in split_classes, so the insertion point is its index\n",
    "    return np.searchsorted(split_classes, original_ids)\n",
    "\n",
    "labels_train = to_class_index(train_loc, train_labels_seen)\n",
    "labels_val = to_class_index(val_loc, val_labels_unseen)\n",
    "labels_trainval = to_class_index(trainval_loc, trainval_labels_seen)\n",
    "labels_test = to_class_index(test_loc, test_labels_unseen)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us denote the features X ∈ [d×m] available at training stage, where d is the dimensionality\n",
    "of the data, and m is the number of instances. We are using ResNet features extracted from the `CUB` dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full feature matrix, then the columns (samples) belonging to each split\n",
    "X_features = res101['features']\n",
    "train_vec, val_vec, trainval_vec, test_vec = (\n",
    "    X_features[:, np.squeeze(att_splits[split]) - 1]\n",
    "    for split in (train_loc, val_loc, trainval_loc, test_loc)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Features for train: (2048, 5875)\n",
      "Features for val: (2048, 2946)\n",
      "Features for trainval: (2048, 7057)\n",
      "Features for test: (2048, 2967)\n"
     ]
    }
   ],
   "source": [
    "# Each feature matrix is (feature dimension, number of samples in the split)\n",
    "print(\"Features for train:\", train_vec.shape)\n",
    "print(\"Features for val:\", val_vec.shape)\n",
    "print(\"Features for trainval:\", trainval_vec.shape)\n",
    "print(\"Features for test:\", test_vec.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Normalize the vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalization(vec,mean,std):\n",
    "    \"\"\"Standardise `vec`: subtract `mean`, then divide the result by `std`.\"\"\"\n",
    "    return (vec - mean) / std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_mean = train_vec.mean(axis=1, keepdims=True)\n",
    "# train_std = np.std(train_vec, axis=1, keepdims = True)\n",
    "# trainval_mean = trainval_vec.mean(axis=1, keepdims = True)\n",
    "# trainval_std = np.std(trainval_vec, axis=1, keepdims=True)\n",
    "\n",
    "# train_vec = normalization(train_vec, train_mean, train_std)\n",
    "# val_vec = normalization(val_vec, train_mean, train_std)\n",
    "\n",
    "# trainval_vec = normalization(trainval_vec, trainval_mean, trainval_std)\n",
    "# test_vec = normalization(test_vec, trainval_mean, trainval_std)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each class in the dataset has an attribute description of length a. Stacking these vectors column-wise gives the `signature matrix` $S \\in [0, 1]^{a \\times z}$ for the z classes of the training stage, and $S' \\in [0, 1]^{a \\times z'}$ for the z' classes of the test stage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Signature matrix: one column of attribute values per class.\n",
    "# Class ids start at 1, so subtract one to obtain column indices.\n",
    "signature = att_splits['att']\n",
    "train_sig, val_sig, trainval_sig, test_sig = (\n",
    "    signature[:, class_ids - 1]\n",
    "    for class_ids in (train_labels_seen, val_labels_unseen, trainval_labels_seen, test_labels_unseen)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a signature matrix, where the occurrence of each attribute for the corresponding class is given.\n",
    "For instance, if the classes are `horse` and `zebra` and the corresponding attributes are [domestic_animal, 4_legged, carnivore]\n",
    "\n",
    "```\n",
    " Horse      Zebra\n",
    "[0.00354613 0.        ] Domestic_animal\n",
    "[0.13829921 0.20209503] 4_legged\n",
    "[0.06560347 0.04155225] carnivore\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.00354613 0.        ]\n",
      " [0.13829921 0.20209503]\n",
      " [0.06560347 0.04155225]]\n"
     ]
    }
   ],
   "source": [
    "# Rows 3-5 of the signature matrix for the first two training classes\n",
    "print(train_sig[3:6,:2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Signature for train: (312, 100)\n",
      "Signature for val: (312, 50)\n",
      "Signature for trainval: (312, 150)\n",
      "Signature for test: (312, 50)\n"
     ]
    }
   ],
   "source": [
    "# Each signature matrix is (number of attributes, number of classes in the split)\n",
    "print(\"Signature for train:\", train_sig.shape)\n",
    "print(\"Signature for val:\", val_sig.shape)\n",
    "print(\"Signature for trainval:\", trainval_sig.shape)\n",
    "print(\"Signature for test:\", test_sig.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample counts (m, n) and class counts (z, z1) for the train/val pair\n",
    "m_train, n_val = len(labels_train), len(labels_val)\n",
    "z_train, z1_val = train_labels_seen.size, val_labels_unseen.size\n",
    "\n",
    "# Same quantities for the trainval/test pair\n",
    "m_trainval, n_test = len(labels_trainval), len(labels_test)\n",
    "z_trainval, z1_test = trainval_labels_seen.size, test_labels_unseen.size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ground truth is a one-hot encoded vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot ground truth for the train split: row = sample, column = class index\n",
    "gt_train = np.zeros((m_train, z_train))\n",
    "gt_train[np.arange(m_train), labels_train.ravel()] = 1\n",
    "\n",
    "# One-hot ground truth for the trainval split\n",
    "gt_trainval = np.zeros((m_trainval, z_trainval))\n",
    "gt_trainval[np.arange(m_trainval), labels_trainval.ravel()] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,\n",
       "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "        0., 0., 0., 0.]])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One-hot row of the first training sample (first 100 columns)\n",
    "gt_train[:1,:100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dimensions for the train split: d = rows of the feature matrix, a = rows of the signature matrix\n",
    "d_train = train_vec.shape[0]\n",
    "a_train = train_sig.shape[0]\n",
    "\n",
    "# Weight matrix of shape (d, a), initialised to zeros\n",
    "V = np.zeros((d_train,a_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dimensions for the trainval split: d = rows of the feature matrix, a = rows of the signature matrix\n",
    "d_trainval = trainval_vec.shape[0]\n",
    "a_trainval = trainval_sig.shape[0]\n",
    "# Weight matrix of shape (d, a), initialised to zeros\n",
    "W = np.zeros((d_trainval,a_trainval))\n",
    "\n",
    "# Regularisation exponents; the solution cell uses them as 10**alph1 and 10**gamm1.\n",
    "# Note: These hyper-parameters were found using the code snippet available below\n",
    "# NOTE(review): the saved output of that snippet is `3 -1`, which does not match\n",
    "# the values set here -- confirm which pair is intended.\n",
    "gamm1 = 3\n",
    "alph1 = 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The proposed closed-form, one-line solution (in the code below, γ is `10**alph1` and λ is `10**gamm1`):\n",
    "```\n",
    "V = inverse(XX' + γI) XYS' inverse(SS' + λI)\n",
    "```\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Left factor: pseudo-inverse of the regularised feature Gram matrix\n",
    "part_1_test = np.linalg.pinv(trainval_vec @ trainval_vec.T + (10**alph1)*np.eye(d_trainval))\n",
    "# Middle factor: features x one-hot labels x transposed signatures\n",
    "part_0_test = trainval_vec @ gt_trainval @ trainval_sig.T\n",
    "# Right factor: pseudo-inverse of the regularised signature Gram matrix\n",
    "part_2_test = np.linalg.pinv(trainval_sig @ trainval_sig.T + (10**gamm1)*np.eye(a_trainval))\n",
    "\n",
    "W = part_1_test @ part_0_test @ part_2_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For inference stage, \n",
    "```\n",
    "argmax(x'VS)\n",
    "```\n",
    "Where S is the signature matrix of the test_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score every test sample against every test class, then keep the best-scoring class per sample\n",
    "outputs_1 = test_vec.T @ W @ test_sig\n",
    "preds_1 = outputs_1.argmax(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The top 1% accuracy is: 40.1153158858462\n"
     ]
    }
   ],
   "source": [
    "# Row-normalise the confusion matrix so that its diagonal holds the per-class accuracy,\n",
    "# then average the diagonal over the test classes and report it as a percentage.\n",
    "cm = confusion_matrix(labels_test, preds_1)\n",
    "cm = cm.astype(float) / cm.sum(axis=1, keepdims=True)\n",
    "avg = sum(np.diag(cm)) / len(test_labels_unseen)\n",
    "print(\"The top 1% accuracy is:\", avg*100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "------------------------------------------------------------------------------------------------\n",
    "The below code snippet can be used to find the best hyper-parameter using the train and val set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3 -1\n"
     ]
    }
   ],
   "source": [
    "# Grid search over the two regularisation exponents on the train/val split.\n",
    "# The winners are stored in best_alpha / best_gamma so that running this cell\n",
    "# does not overwrite alph1, gamm1, cm or avg used by the trainval/test cells.\n",
    "best_acc, best_alpha, best_gamma = -np.inf, None, None\n",
    "exponents = range(-3, 4)\n",
    "\n",
    "# Products that do not depend on alpha or gamma are computed once\n",
    "train_gram = np.matmul(train_vec, train_vec.transpose())\n",
    "train_xys = np.matmul(np.matmul(train_vec, gt_train), train_sig.transpose())\n",
    "sig_gram = np.matmul(train_sig, train_sig.transpose())\n",
    "\n",
    "# The signature-side inverse depends on gamma only: one per exponent\n",
    "sig_inv = {gamma: np.linalg.pinv(sig_gram + (10**gamma)*np.eye(a_train)) for gamma in exponents}\n",
    "\n",
    "for alpha in exponents:\n",
    "    # The feature-side inverse depends on alpha only, so keep it out of the inner loop\n",
    "    part_1 = np.linalg.pinv(train_gram + (10**alpha)*np.eye(d_train))\n",
    "    left = np.matmul(part_1, train_xys)\n",
    "    for gamma in exponents:\n",
    "        #One line solution\n",
    "        V = np.matmul(left, sig_inv[gamma])\n",
    "\n",
    "        #predictions on the validation classes\n",
    "        outputs = np.matmul(np.matmul(val_vec.transpose(), V), val_sig)\n",
    "        preds = np.argmax(outputs, axis=1)\n",
    "\n",
    "        # mean per-class accuracy from the row-normalised confusion matrix\n",
    "        val_cm = confusion_matrix(labels_val, preds)\n",
    "        val_cm = val_cm.astype('float') / val_cm.sum(axis=1)[:, np.newaxis]\n",
    "        val_acc = sum(val_cm.diagonal())/len(val_labels_unseen)\n",
    "\n",
    "        # strict '>' keeps the first of several equally good settings\n",
    "        if val_acc > best_acc:\n",
    "            best_acc = val_acc\n",
    "            best_alpha = alpha\n",
    "            best_gamma = gamma\n",
    "print(best_alpha, best_gamma)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "work",
   "language": "python",
   "name": "work"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
